"""Small miscellaneous utilities."""

import torch


def recursive_to_device(data, device, *, non_blocking=False):
    """Recursively move every tensor in a nested container to ``device``.

    Tensors are moved with ``Tensor.to``. Dicts, lists and tuples (including
    namedtuples) are traversed recursively. Any other object is returned
    unchanged.

    Args:
        data: A tensor, or an arbitrarily nested dict / list / tuple that may
            contain tensors.
        device: Target device; anything accepted by ``Tensor.to``
            (e.g. ``"cuda:0"`` or a ``torch.device``).
        non_blocking: Forwarded to ``Tensor.to``. Defaults to ``False``,
            which matches the previous behavior.

    Returns:
        A structure of the same shape with all tensors moved. Dicts
        (including dict subclasses) come back as plain ``dict`` and lists as
        plain ``list``. Tuples keep their concrete type, so namedtuples,
        ``torch.Size`` and similar tuple types are preserved.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=non_blocking)
    if isinstance(data, dict):
        return {
            k: recursive_to_device(v, device, non_blocking=non_blocking)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [
            recursive_to_device(v, device, non_blocking=non_blocking)
            for v in data
        ]
    if isinstance(data, tuple):
        items = [
            recursive_to_device(v, device, non_blocking=non_blocking)
            for v in data
        ]
        # namedtuple constructors take each field positionally rather than a
        # single iterable, so they must be rebuilt with unpacking.
        if hasattr(data, "_fields"):
            return type(data)(*items)
        # Plain tuples and other tuple types such as torch.Size accept a
        # single iterable.
        return type(data)(items)
    return data
